# 先验框解码调整预测框
import tensorflow as tf

#（一）对某一个输出特征层解码
def anchors_decode(feats, anchors, num_classes, inputs_shape):
    '''Decode one YOLO output feature level into normalized box predictions.

    Args:
        feats: raw output of one feature level, shape
            [b, grid_h, grid_w, num_anchors * (5 + num_classes)],
            e.g. [b, 13, 13, 3 * (5 + num_classes)]. The grid size is read
            at run time, so the spatial dims may be unknown statically.
        anchors: anchor sizes (w, h) for this level, in the same units as
            inputs_shape, shape [num_anchors, 2] (e.g. [3, 2]).
        num_classes: number of object classes.
        inputs_shape: network input size (h, w), e.g. (416, 416).

    Returns:
        box_xy: box centres (x, y) divided by the grid size, so in (0, 1);
            shape [b, grid_h, grid_w, num_anchors, 2].
        box_wh: box sizes (w, h) divided by the input size (w, h);
            shape [b, grid_h, grid_w, num_anchors, 2].
        box_conf: objectness score (sigmoid), shape [..., 1].
        box_prob: per-class conditional probabilities (sigmoid),
            shape [..., num_classes].
    '''
    anchors = tf.convert_to_tensor(anchors)
    # Number of anchors per grid cell (3 in standard YOLOv3).
    num_anchors = anchors.shape[0]

    # Dynamic grid size (grid_h, grid_w): unlike feats.shape, this also works
    # when the spatial dims are None (variable input size / tf.function).
    grid_shape = tf.shape(feats)[1:3]
    grid_h, grid_w = grid_shape[0], grid_shape[1]

    # (1) Cell-index grid.
    # x index varies along axis 1: [1, grid_w, 1, 1] -> [grid_h, grid_w, A, 1]
    grid_x = tf.tile(tf.reshape(tf.range(grid_w), [1, -1, 1, 1]),
                     [grid_h, 1, num_anchors, 1])
    # y index varies along axis 0: [grid_h, 1, 1, 1] -> [grid_h, grid_w, A, 1]
    grid_y = tf.tile(tf.reshape(tf.range(grid_h), [-1, 1, 1, 1]),
                     [1, grid_w, num_anchors, 1])
    # [grid_h, grid_w, A, 2] holding (x, y). Cast to feats.dtype (not a fixed
    # float32) so float16 / mixed-precision outputs can be added to it.
    grid = tf.cast(tf.concat([grid_x, grid_y], axis=-1), feats.dtype)

    # (2) Anchor sizes (w, h): [A, 2] -> [1, 1, A, 2]; broadcasting over the
    # grid replaces an explicit tile.
    anchors_tensor = tf.cast(tf.reshape(anchors, [1, 1, num_anchors, 2]),
                             feats.dtype)

    # (3) Split the channel dim per anchor:
    # [b, gh, gw, A*(5+C)] -> [b, gh, gw, A, 5+C].
    # The last axis is (x, y, w, h, objectness, class_0 .. class_{C-1}).
    feats = tf.reshape(feats,
                       [-1, grid_h, grid_w, num_anchors, 5 + num_classes])

    # (4) Centres: the sigmoid keeps the offset inside the cell. Adding the
    # cell index and dividing by the grid size (w, h) yields values in (0, 1).
    grid_wh = tf.cast(grid_shape[::-1], feats.dtype)
    box_xy = (tf.nn.sigmoid(feats[..., :2]) + grid) / grid_wh

    # Sizes: exp(raw) scales the anchor size, then normalize by the input (w, h).
    inputs_shape_wh = tf.cast(inputs_shape[::-1], feats.dtype)
    box_wh = tf.exp(feats[..., 2:4]) * anchors_tensor / inputs_shape_wh

    # Objectness and per-class conditional probabilities.
    box_conf = tf.nn.sigmoid(feats[..., 4:5])
    box_prob = tf.nn.sigmoid(feats[..., 5:])

    return box_xy, box_wh, box_conf, box_prob


#（二）将归一化后的预测框信息转换为对应的原图坐标
def correct_box(box_xy, box_wh, inputs_shape, image_shape, letterbox=False):
    '''Convert normalized box centres/sizes into corner boxes in image pixels.

    Args:
        box_xy: box centres (x, y) normalized relative to the network input;
            last dim 2.
        box_wh: box sizes (w, h) normalized by the network input size;
            last dim 2.
        inputs_shape: network input size (h, w), e.g. (416, 416). Only used
            when ``letterbox`` is True.
        image_shape: original image size (h, w), height first, because the
            boxes are scaled by [h, w, h, w].
        letterbox: set True if the original image was letterboxed (aspect-
            preserving resize plus centred padding) to ``inputs_shape``
            before inference. The padding offset and scale are then removed
            before mapping to image pixels. Assumes the resized content has
            size round(image_shape * min(inputs_shape / image_shape)).
            Defaults to False, which keeps the original behaviour: the image
            is assumed to have been resized directly to the input size.

    Returns:
        Boxes [y_min, x_min, y_max, x_max] in original-image pixels, with the
        same leading dims as ``box_xy`` and last dim 4.
    '''
    # Put y first so the coordinates line up with (h, w) shapes.
    box_yx = box_xy[..., ::-1]
    box_hw = box_wh[..., ::-1]

    inputs_shape = tf.cast(inputs_shape, box_yx.dtype)
    image_shape = tf.cast(image_shape, box_yx.dtype)

    if letterbox:
        # Size of the resized image content inside the padded network input.
        new_shape = tf.round(
            image_shape * tf.reduce_min(inputs_shape / image_shape))
        # Normalized padding offset, and scale from input-normalized to
        # content-normalized coordinates.
        offset = (inputs_shape - new_shape) / 2.0 / inputs_shape
        scale = inputs_shape / new_shape
        box_yx = (box_yx - offset) * scale
        box_hw = box_hw * scale

    # Top-left and bottom-right corners, still normalized.
    box_min = box_yx - (box_hw / 2.0)
    box_max = box_yx + (box_hw / 2.0)

    # [y_min, x_min, y_max, x_max], scaled to pixels by [h, w, h, w].
    boxes = tf.concat([box_min, box_max], axis=-1)
    boxes = boxes * tf.concat([image_shape, image_shape], axis=-1)

    return boxes


#（三）计算一个特征层的检测框坐标和类别的概率
def anchors_prediction(feats, anchors, num_classes, inputs_shape, image_shape):
    '''Compute flattened boxes and class scores for one output feature level.

    Args:
        feats: raw output of one feature level,
            e.g. shape [b, 13, 13, 3 * (5 + num_classes)].
        anchors: anchor sizes for this level, e.g. shape [3, 2].
        num_classes: number of object classes.
        inputs_shape: network input size (h, w), e.g. (416, 416).
        image_shape: original image size, passed unchanged to correct_box.

    Returns:
        (boxes, box_scores): boxes reshaped to [n, 4] as produced by
        correct_box, and box_scores = objectness * class probability
        reshaped to [n, num_classes].
    '''
    decoded = anchors_decode(feats, anchors, num_classes, inputs_shape)
    box_xy, box_wh, box_conf, box_prob = decoded

    # One row per predicted box across batch, grid cells and anchors.
    flat_boxes = tf.reshape(
        correct_box(box_xy, box_wh, inputs_shape, image_shape), [-1, 4])

    # Class-specific score = objectness * conditional class probability.
    flat_scores = tf.reshape(box_conf * box_prob, [-1, num_classes])

    return flat_boxes, flat_scores


# 验证
if __name__ == '__main__':

    # Random stand-in for one 13x13 output level:
    # 3 anchors * (5 + 20 classes) = 75 channels.
    feat = tf.random.normal([4, 13, 13, 75], mean=0, stddev=0.5)
    # Anchor sizes (w, h) for this level, in input-image pixels.
    anchors = tf.constant([[142, 110], [192, 243], [459, 401]])

    # Decode and map the boxes to original-image pixels. Both shapes are
    # given as (h, w), so a 1280x720 (w x h) image is passed as (720, 1280).
    boxes, box_scores = anchors_prediction(feat, anchors, num_classes=20,
                                           inputs_shape=(416, 416),
                                           image_shape=(720, 1280))
    print(boxes)
    print(box_scores)




